import os
import cv2
import json
import itertools
import numpy as np
import os.path as osp
from collections import deque

from dust3r.utils.image import imread_cv2
from .base_many_view_dataset import BaseManyViewDataset


class Co3d(BaseManyViewDataset):
    def __init__(self, mask_bg=True, use_comb=True,
                 scene_class=None, scene_id=None,
                 num_seq=100, num_frames=5, 
                 min_thresh=5, max_thresh=20,
                 full_video=False, lb=0, ub=30,
                 kf_every=1, *args, ROOT, **kwargs):
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)

        assert mask_bg in (True, False, 'rand')
        self.mask_bg = mask_bg
    
        self.num_seq = num_seq
        self.num_frames = num_frames
        self.max_thresh = max_thresh
        self.min_thresh = min_thresh

        self.full_video = full_video
        self.kf_every = kf_every
        self.use_comb = use_comb

        self.scenes, self.scene_list = self.load_scene(scene_class, scene_id)

        self.combinations, self.num_seq = self.get_combinations(use_comb, lb, ub)
        self.invalidate = {scene: {} for scene in self.scene_list}
    

    def get_combinations(self, use_comb, lb, ub):

        if use_comb and not self.full_video:
            print('Using combinations')
            combinations = list(itertools.combinations(range(100), self.num_frames))
            combinations = [combo for combo in combinations if all(lb < abs(x-y) <= ub and abs(x-y) % 5 == 0 for x, y in zip(combo, combo[1:]))]
            num_seq = len(combinations)
            print('Number of sequences:', num_seq)
        else:
            combinations = None
            num_seq = self.num_seq
        
        return combinations, num_seq

    

    def load_scene(self, scene_class=None, scene_id=None):
        print('Loading scenes')
        with open(osp.join(self.ROOT, f'selected_seqs_{self.split}.json'), 'r') as f:
            scenes = json.load(f)
            
            if scene_class is not None:
                scenes = {k: v for k, v in scenes.items() if k == scene_class}
            else:
                scenes = {k: v for k, v in scenes.items() if len(v) > 0} # k is class (apple), v is corresponding list
            
            if scene_id is not None:
                scenes = {(k, k2): v2 for k, v in scenes.items() for k2, v2 in v.items() if k2 == scene_id}
            else:
                scenes = {(k, k2): v2 for k, v in scenes.items() 
                            for k2, v2 in v.items()} # k is class (apple), k2 is instance (110_13051_23361), v2 is list of image idx
        scene_list = list(scenes.keys())
        
        return scenes, scene_list
    
    def __len__(self):

        return len(self.scene_list) * self.num_seq
    
    def _get_views(self, idx, resolution, rng, attempts=0):
        obj, instance = self.scene_list[idx // self.num_seq]
        image_pool = self.scenes[obj, instance]

        if self.use_comb and not self.full_video:
            frame_idx = self.combinations[idx % len(self.combinations)]

            last = len(image_pool)-1
            imgs_idxs = [max(0, min(im_idx + rng.integers(-4, 5), last)) for im_idx in frame_idx]

        else:
            img_idx = range(0, len(image_pool))
            imgs_idxs = self.sample_frames(img_idx, rng)
        

        if resolution not in self.invalidate[obj, instance]:  # flag invalid images
            self.invalidate[obj, instance][resolution] = [False for _ in range(len(image_pool))]
        
        mask_bg = (self.mask_bg == True) or (self.mask_bg == 'rand' and rng.choice(2))
       
        imgs_idxs = deque(imgs_idxs)

        max_depth_min = 1e8
        max_depth_max = 0.      
        max_depth_first = None  
        

        views = []

        while len(imgs_idxs) > 0: 
            im_idx = imgs_idxs.popleft()

            if self.invalidate[obj, instance][resolution][im_idx]:
                # search for a valid image
                random_direction = 2 * rng.choice(2) - 1
                for offset in range(1, len(image_pool)):
                    tentative_im_idx = (im_idx + (random_direction * offset)) % len(image_pool)
                    if not self.invalidate[obj, instance][resolution][tentative_im_idx]:
                        im_idx = tentative_im_idx
                        break

            view_idx = image_pool[im_idx]

            impath = osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06d}.jpg')

            # load camera params
            input_metadata = np.load(impath.replace('jpg', 'npz'))
            camera_pose = input_metadata['camera_pose'].astype(np.float32)
            intrinsics = input_metadata['camera_intrinsics'].astype(np.float32)

            rgb_image = imread_cv2(impath)
            depthmap = imread_cv2(impath.replace('images', 'depths') + '.geometric.png', cv2.IMREAD_UNCHANGED)
            depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num(input_metadata['maximum_depth'])

            if mask_bg:
                # load object mask
                maskpath = osp.join(self.ROOT, obj, instance, 'masks', f'frame{view_idx:06d}.png')
                maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32)
                maskmap = (maskmap / 255.0) > 0.1

                # update the depthmap with mask
                depthmap *= maskmap
                            
            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)

            num_valid = (depthmap > 0.0).sum()
            if num_valid == 0:
                # problem, invalidate image and retry
                self.invalidate[obj, instance][resolution][im_idx] = True
                imgs_idxs.appendleft(im_idx)
                continue

            if input_metadata['maximum_depth'] > max_depth_max:
                max_depth_max = input_metadata['maximum_depth']
            
            if input_metadata['maximum_depth'] < max_depth_min:
                max_depth_min = input_metadata['maximum_depth']
            
            if max_depth_first is None:
                max_depth_first = input_metadata['maximum_depth']
            
            
            
            views.append(dict(
                img=rgb_image,
                depthmap=depthmap,
                camera_pose=camera_pose,
                camera_intrinsics=intrinsics,
                dataset='Co3d_v2',
                label=osp.join(obj, instance),
                instance=osp.split(impath)[1]
            ))

        if max_depth_max / max_depth_min > 100. or max_depth_max / max_depth_first > 10.:
            new_idx = rng.integers(0, self.__len__()-1)
            return self._get_views(new_idx, resolution, rng)
        
        return views







    


